{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e77b74d8-b804-48e9-8f0d-cdf5e8eab980",
   "metadata": {},
   "source": [
    "# GSM8K Performance with smaller scaled CoT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5af440f5-9a49-4805-843d-8cebd2924f3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "import re\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b4394258-b66c-4d37-9284-85a7585063fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tenacity import (\n",
    "    retry,\n",
    "    stop_after_attempt,\n",
    "    wait_chain,\n",
    "    wait_fixed\n",
    ") \n",
    "\n",
    "@retry(wait=wait_chain(*[wait_fixed(3) for i in range(3)] +\n",
    "                       [wait_fixed(5) for i in range(2)] +\n",
    "                       [wait_fixed(10)]))\n",
    "def completion_with_backoff(**kwargs):\n",
    "    return openai.Completion.create(**kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f2be408c-4589-44c2-b478-0d813c2a880d",
   "metadata": {},
   "outputs": [],
   "source": [
    "openai.api_key = \"sk-\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "173afa58-a394-4236-92b1-428f1b1a693c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset gsm8k (/Users/yaofu/.cache/huggingface/datasets/gsm8k/main/1.1.0/37bfb08b1d4fcbb01f06b03d9e1ef5f1fcbd4d3af3d08842c50d7305091285ba)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4e2cfdcd17454e2dbdafaba4d7afa89b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "gsm8k = load_dataset('gsm8k', 'main')\n",
    "validation_index = np.load('../gsm8k/lib_prompt/validation_index.npy')\n",
    "validation_data = gsm8k['train'].select(validation_index)\n",
    "gsm8k_test = gsm8k['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "85b66f4a-79ff-4d67-87a9-5d4b29bd17a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_simple_4_cases = open('../gsm8k/lib_prompt/prompt_simple_4_cases.txt').read()\n",
    "prompt_simple_4_cases_ao = open('../gsm8k/lib_prompt/prompt_simple_4_cases_ao.txt').read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c3f649a5-3791-4beb-ac20-386fd9d97a0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_answer(pred_str, ans_str):\n",
    "    pattern = '\\d*\\.?\\d+'\n",
    "    pred = re.findall(pattern, pred_str)\n",
    "    if(len(pred) >= 1):\n",
    "        # print(pred_str)\n",
    "        pred = pred[-1]\n",
    "        gold = re.findall(pattern, ans_str)\n",
    "        # print(ans_str)\n",
    "        gold = gold[-1]\n",
    "        return pred == gold\n",
    "    else: return False\n",
    "\n",
    "def parse_pred_ans(filename):\n",
    "    with open(filename) as fd: lines = fd.readlines()\n",
    "    am, a = None, None\n",
    "    num_q, acc = 0, 0\n",
    "    current_mode = 'none'\n",
    "    questions = []\n",
    "    ans_pred = []\n",
    "    ans_gold = []\n",
    "    for l in lines:\n",
    "        if(l.startswith('Q: ')):\n",
    "            if(am is not None and a is not None):\n",
    "                questions.append(q)\n",
    "                ans_pred.append(am)\n",
    "                ans_gold.append(a)\n",
    "                if(test_answer(am, a)):\n",
    "                    acc += 1\n",
    "            current_mode = 'q'\n",
    "            q = l\n",
    "            num_q += 1\n",
    "        elif(l.startswith('A_model:')):\n",
    "            current_mode = 'am'\n",
    "            am = l\n",
    "        elif(l.startswith('A:')):\n",
    "            current_mode = 'a'\n",
    "            a = l\n",
    "        else:\n",
    "            if(current_mode == 'q'): q += l\n",
    "            elif(current_mode == 'am'): am += l\n",
    "            elif(current_mode == 'a'): a += l\n",
    "            else:\n",
    "                raise ValueError(current_mode)\n",
    "                \n",
    "    questions.append(q)\n",
    "    ans_pred.append(am)\n",
    "    ans_gold.append(a)\n",
    "    if(test_answer(am, a)):\n",
    "        acc += 1\n",
    "    print('num_q %d correct %d ratio %.4f' % (num_q, acc, float(acc / num_q)))\n",
    "    return questions, ans_pred, ans_gold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "71bee908-02cc-41fb-acaa-53e9af736427",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_finished(ans_model):\n",
    "    if('answer is' in ans_model): return True\n",
    "    else: return False\n",
    "\n",
    "def extract_ans(ans_model):\n",
    "    ans_model = ans_model.split('\\n')\n",
    "    ans = []\n",
    "    residual = []\n",
    "    for li, al in enumerate(ans_model):\n",
    "        ans.append(al)\n",
    "        if('answer is' in al):\n",
    "            break\n",
    "    residual = list(ans_model[li + 1:])\n",
    "    ans = '\\n'.join(ans)\n",
    "    residual = '\\n'.join(residual)\n",
    "    return ans, residual"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18b80b9d-d317-4345-8e35-bea126ba0144",
   "metadata": {},
   "source": [
    "# Curie CoT, Acc 2.05"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4fb3f9a6-04a9-4140-8c1d-a6eb91d82070",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [16:09<00:00,  1.36it/s]\n"
     ]
    }
   ],
   "source": [
    "i = 0\n",
    "with open('outputs/test_curie_simple_4_cases.txt', 'w') as fd:\n",
    "    for q, a in tqdm(zip(gsm8k_test['question'], gsm8k_test['answer']), \n",
    "                               total=len(gsm8k_test['question'])):\n",
    "        prompt_q = prompt_simple_4_cases + '\\nQuestion: ' + q + '\\n'\n",
    "        \n",
    "        \n",
    "        response = completion_with_backoff(model=\"text-curie-001\", \n",
    "                                            prompt=prompt_q, \n",
    "                                            temperature=0, \n",
    "                                            max_tokens=256)\n",
    "        ans_model = response['choices'][0]['text']\n",
    "        \n",
    "        ans_, residual = extract_ans(ans_model)\n",
    "            \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))\n",
    "        i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c2a4cbe1-f46b-423f-b547-7e7aaab1c6ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 1319 correct 27 ratio 0.0205\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('outputs/test_curie_simple_4_cases.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b9372c7-6875-4b3d-a16a-ad09112ed1e7",
   "metadata": {},
   "source": [
    "# Curie AO Acc 3.26"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "6747ee24-66ce-4031-b82c-f23a58d5e856",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [07:23<00:00,  2.98it/s]\n"
     ]
    }
   ],
   "source": [
    "i = 0\n",
    "with open('outputs/test_curie_simple_4_cases_ao.txt', 'w') as fd:\n",
    "    for q, a in tqdm(zip(gsm8k_test['question'], gsm8k_test['answer']), \n",
    "                               total=len(gsm8k_test['question'])):\n",
    "        prompt_q = prompt_simple_4_cases_ao + '\\nQuestion: ' + q + '\\n'\n",
    "        \n",
    "        \n",
    "        response = completion_with_backoff(model=\"text-curie-001\", \n",
    "                                            prompt=prompt_q, \n",
    "                                            temperature=0, \n",
    "                                            max_tokens=256)\n",
    "        ans_model = response['choices'][0]['text']\n",
    "        \n",
    "        ans_, residual = extract_ans(ans_model)\n",
    "            \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))\n",
    "        i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "eb7b1079-5a83-4ed0-9eac-97a46771cb5b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 1319 correct 43 ratio 0.0326\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('outputs/test_curie_simple_4_cases_ao.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a444de0a-1f42-435a-a759-699b6378d01c",
   "metadata": {},
   "source": [
    "# Ada CoT, Acc 2.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d7caabfc-7540-4881-bd32-a6c200891763",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [08:07<00:00,  2.71it/s]\n"
     ]
    }
   ],
   "source": [
    "i = 0\n",
    "with open('outputs/test_ada_simple_4_cases.txt', 'w') as fd:\n",
    "    for q, a in tqdm(zip(gsm8k_test['question'], gsm8k_test['answer']), \n",
    "                               total=len(gsm8k_test['question'])):\n",
    "        prompt_q = prompt_simple_4_cases + '\\nQuestion: ' + q + '\\n'\n",
    "        \n",
    "        \n",
    "        response = completion_with_backoff(model=\"text-ada-001\", \n",
    "                                            prompt=prompt_q, \n",
    "                                            temperature=0, \n",
    "                                            max_tokens=256)\n",
    "        ans_model = response['choices'][0]['text']\n",
    "        \n",
    "        ans_, residual = extract_ans(ans_model)\n",
    "            \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))\n",
    "        i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "f5eb04a8-689e-4b04-beb0-14c119d08b24",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 1319 correct 33 ratio 0.0250\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('outputs/test_ada_simple_4_cases.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8d3aba8-4c89-41ad-9f06-362f1af1da95",
   "metadata": {},
   "source": [
    "# Ada AO Acc 2.05"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "f60f9e66-7dc3-4b70-b0c3-ba7775a3a24b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [06:48<00:00,  3.23it/s]\n"
     ]
    }
   ],
   "source": [
    "i = 0\n",
    "with open('outputs/test_ada_simple_4_cases_ao.txt', 'w') as fd:\n",
    "    for q, a in tqdm(zip(gsm8k_test['question'], gsm8k_test['answer']), \n",
    "                               total=len(gsm8k_test['question'])):\n",
    "        prompt_q = prompt_simple_4_cases_ao + '\\nQuestion: ' + q + '\\n'\n",
    "        \n",
    "        \n",
    "        response = completion_with_backoff(model=\"text-ada-001\", \n",
    "                                            prompt=prompt_q, \n",
    "                                            temperature=0, \n",
    "                                            max_tokens=256)\n",
    "        ans_model = response['choices'][0]['text']\n",
    "        \n",
    "        ans_, residual = extract_ans(ans_model)\n",
    "            \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))\n",
    "        i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "100faf04-3abc-42d6-b87a-eecfe806b109",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 1319 correct 27 ratio 0.0205\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('outputs/test_ada_simple_4_cases_ao.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e88a80f-a18c-4363-9767-ab0cda740f46",
   "metadata": {},
   "source": [
    "# Babbage CoT, Acc 1.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "153ca3e1-1a96-4928-a7cf-0c69f841e642",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [11:20<00:00,  1.94it/s]\n"
     ]
    }
   ],
   "source": [
    "i = 0\n",
    "with open('outputs/test_babbage_simple_4_cases.txt', 'w') as fd:\n",
    "    for q, a in tqdm(zip(gsm8k_test['question'], gsm8k_test['answer']), \n",
    "                               total=len(gsm8k_test['question'])):\n",
    "        prompt_q = prompt_simple_4_cases + '\\nQuestion: ' + q + '\\n'\n",
    "        \n",
    "        \n",
    "        response = completion_with_backoff(model=\"text-babbage-001\", \n",
    "                                            prompt=prompt_q, \n",
    "                                            temperature=0, \n",
    "                                            max_tokens=256)\n",
    "        ans_model = response['choices'][0]['text']\n",
    "        \n",
    "        ans_, residual = extract_ans(ans_model)\n",
    "            \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))\n",
    "        i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "8858a1a3-2878-4d73-bc69-a282b78e7251",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 1319 correct 23 ratio 0.0174\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('outputs/test_babbage_simple_4_cases.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a482e665-9e05-4c0b-81e9-3c3810e4a0ce",
   "metadata": {},
   "source": [
    "# Babbage AO Acc 2.65"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "11619856-aa87-4ce2-b7ad-cde05716dd79",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [05:24<00:00,  4.07it/s]\n"
     ]
    }
   ],
   "source": [
    "i = 0\n",
    "with open('outputs/test_babbage_simple_4_cases_ao.txt', 'w') as fd:\n",
    "    for q, a in tqdm(zip(gsm8k_test['question'], gsm8k_test['answer']), \n",
    "                               total=len(gsm8k_test['question'])):\n",
    "        prompt_q = prompt_simple_4_cases_ao + '\\nQuestion: ' + q + '\\n'\n",
    "        \n",
    "        \n",
    "        response = completion_with_backoff(model=\"text-babbage-001\", \n",
    "                                            prompt=prompt_q, \n",
    "                                            temperature=0, \n",
    "                                            max_tokens=256)\n",
    "        ans_model = response['choices'][0]['text']\n",
    "        \n",
    "        ans_, residual = extract_ans(ans_model)\n",
    "            \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))\n",
    "        i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "f5b2b658-48ce-4907-98d4-22ed3388e9e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 1319 correct 35 ratio 0.0265\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('outputs/test_babbage_simple_4_cases_ao.txt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
